package com.lhy.hierarchy.ruleEngine.composite.utils;

import com.hauxsoft.core.util.result.Result;
import com.hauxsoft.core.util.result.Results;
import com.hauxsoft.data.CompileResult;
import com.hauxsoft.entity.JavaRuleDO;
import lombok.extern.slf4j.Slf4j;

import javax.tools.*;
import java.io.*;
import java.net.JarURLConnection;
import java.net.URI;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.CharBuffer;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * 规则工具类
 * @author z_hh  
 * @date 2018年12月7日
 */
@Slf4j
public class DynamicRuleUtils {
    private static final String CLASS_SUFFIX = ".class";
    private static final String CLASS_FILE_PREFIX = File.separator + "classes"  + File.separator;
    private static final String PACKAGE_SEPARATOR = ".";
    
    /**
     * 查找包下的所有类的名字
     * @param packageName
     * @param showChildPackageFlag 是否需要显示子包内容
     * @return List集合，内容为类的全名
     */
    public static List<String> getClazzName(String packageName, boolean showChildPackageFlag ) {
        List<String> result = new ArrayList<>();
        String suffixPath = packageName.replaceAll("\\.", "/");
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        try {
            Enumeration<URL> urls = loader.getResources(suffixPath);
            while(urls.hasMoreElements()) {
                URL url = urls.nextElement();
                if(url != null) {
                    String protocol = url.getProtocol();
                    if("file".equals(protocol)) {
                        String path = url.getPath();
                        result.addAll(getAllClassNameByFile(new File(path), showChildPackageFlag));
                    } else if("jar".equals(protocol)) {
                        JarFile jarFile = null;
                        try{
                            jarFile = ((JarURLConnection) url.openConnection()).getJarFile();
                        } catch(Exception e){
                            e.printStackTrace();
                        }
                        if(jarFile != null) {
                            result.addAll(getAllClassNameByJar(jarFile, packageName, showChildPackageFlag));
                        }
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return result;
    }
    
    
    /**
     * 递归获取所有class文件的名字
     * @param file 
     * @param flag  是否需要迭代遍历
     * @return List
     */
    private static List<String> getAllClassNameByFile(File file, boolean flag) {
        List<String> result =  new ArrayList<>();
        if(!file.exists()) {
            return result;
        }
        if(file.isFile()) {
            String path = file.getPath();
            // 注意：这里替换文件分割符要用replace。因为replaceAll里面的参数是正则表达式,而windows环境中File.separator="\\"的,因此会有问题
            if(path.endsWith(CLASS_SUFFIX)) {
                path = path.replace(CLASS_SUFFIX, "");
                // 从"/classes/"后面开始截取
                String clazzName = path.substring(path.indexOf(CLASS_FILE_PREFIX) + CLASS_FILE_PREFIX.length())
                        .replace(File.separator, PACKAGE_SEPARATOR);
                if(-1 == clazzName.indexOf("$")) {
                    result.add(clazzName);
                }
            }
            return result;
            
        } else {
            File[] listFiles = file.listFiles();
            if(listFiles != null && listFiles.length > 0) {
                for (File f : listFiles) {
                    if(flag) {
                        result.addAll(getAllClassNameByFile(f, flag));
                    } else {
                        if(f.isFile()){
                            String path = f.getPath();
                            if(path.endsWith(CLASS_SUFFIX)) {
                                path = path.replace(CLASS_SUFFIX, "");
                                // 从"/classes/"后面开始截取
                                String clazzName = path.substring(path.indexOf(CLASS_FILE_PREFIX) + CLASS_FILE_PREFIX.length())
                                        .replace(File.separator, PACKAGE_SEPARATOR);
                                if(-1 == clazzName.indexOf("$")) {
                                    result.add(clazzName);
                                }
                            }
                        }
                    }
                }
            } 
            return result;
        }
    }
    
    
    /**
     * 递归获取jar所有class文件的名字
     * @param jarFile 
     * @param packageName 包名
     * @param flag  是否需要迭代遍历
     * @return List
     */
    private static List<String> getAllClassNameByJar(JarFile jarFile, String packageName, boolean flag) {
        List<String> result =  new ArrayList<>();
        Enumeration<JarEntry> entries = jarFile.entries();
        while(entries.hasMoreElements()) {
            JarEntry jarEntry = entries.nextElement();
            String name = jarEntry.getName();
            // 判断是不是class文件
            if(name.endsWith(CLASS_SUFFIX)) {
                name = name.replace(CLASS_SUFFIX, "").replace("/", ".");
                if(flag) {
                    // 如果要子包的文件,那么就只要开头相同且不是内部类就ok
                    if(name.startsWith(packageName) && -1 == name.indexOf("$")) {
                        result.add(name);
                    }
                } else {
                    // 如果不要子包的文件,那么就必须保证最后一个"."之前的字符串和包名一样且不是内部类
                    if(packageName.equals(name.substring(0, name.lastIndexOf("."))) && -1 == name.indexOf("$")) {
                        result.add(name);
                    }
                }
            }
        }
        return result;
    }
    
    // 编译版本
    private static final String TARGET_CLASS_VERSION = "1.8";

	/**
	 * auto fill in the java-name with code, return null if cannot find the public class
	 * 
	 * @param javaSrc source code string
	 * @return return the Map, the KEY means ClassName, the VALUE means bytecode.
	 * @throws RuntimeException
	 */
	public static Result<CompileResult> compile(String javaSrc) throws RuntimeException {
		Pattern pattern = Pattern.compile("public\\s+class\\s+(\\w+)");
		Matcher matcher = pattern.matcher(javaSrc);
		if (matcher.find()) {
			return compile(matcher.group(1) + ".java", javaSrc);
		}
		return Results.error("找不到类名称！");
	}

	/**
	 * @param javaName the name of your public class,eg: <code>TestClass.java</code>
	 * @param javaSrc source code string
	 * @return return the Map, the KEY means ClassName, the VALUE means bytecode.
	 * @throws RuntimeException
	 */
	public static Result<CompileResult> compile(String javaName, String javaSrc) throws RuntimeException {
		JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
		StandardJavaFileManager stdManager = compiler.getStandardFileManager(null, null, null);
		try (MemoryJavaFileManager manager = new MemoryJavaFileManager(stdManager)) {
			JavaFileObject javaFileObject = MemoryJavaFileManager.makeStringSource(javaName, javaSrc);
			DiagnosticCollector<JavaFileObject> collector = new DiagnosticCollector<>();
			List<String> options = new ArrayList<>();
	        options.add("-target");
	        options.add(TARGET_CLASS_VERSION);
			JavaCompiler.CompilationTask task = compiler.getTask(null, manager, collector, options, null,
					Arrays.asList(javaFileObject));
			if (task.call()) {
				return Results.success(CompileResult.builder()
						.mainClassFileName(javaName)
						.byteCode(manager.getClassBytes())
						.build());
			}
			String errorMessage = collector.getDiagnostics().stream()
				.map(diagnostics -> diagnostics.toString())
				.reduce("", (s1, s2) -> s1 + "\n" +s2);
			return Results.error(errorMessage);
		} catch (IOException e) {
			log.error("编译出错啦！", e);
			return Results.error(e.getMessage());
		}
	}
	
	/**
	 * 从JavaRuleDO获取规则Class对象
	 * @param javaRule
	 * @throws Exception
	 */
	public static Class<?> getRuleClass(JavaRuleDO javaRule) throws Exception {
		Map<String, byte[]> bytecode = new HashMap<>();
		String fullName = Objects.requireNonNull(javaRule.getFullClassName(), "全类名不能为空！");
		bytecode.put(fullName, javaRule.getByteContent());
		try (MemoryClassLoader classLoader = new MemoryClassLoader(bytecode)) {
			return classLoader.findClass(fullName);
		} catch (Exception e) {
			log.error("加载类{}异常！", fullName);
			throw e;
		}
	}
	
	/**
	 * 从JavaRuleDO获取规则实例对象
	 * @param javaRule
	 * @throws Exception
	 */
	public static BaseRule getRuleInstance(JavaRuleDO javaRule) throws Exception {
		try {
			BaseRule rule = (BaseRule) getRuleClass(javaRule).newInstance();
			// 设置优先级
			rule.setPriority(javaRule.getSort());
			return rule;
		} catch (Exception e) {
			log.error("从JavaRuleDO获取规则实例异常！", e);
			throw e;
		}
	}
	
	/*
	 * a temp memory class loader
	 */
    private static class MemoryClassLoader extends URLClassLoader {
		Map<String, byte[]> classBytes = new HashMap<String, byte[]>();

		public MemoryClassLoader(Map<String, byte[]> classBytes) {
			super(new URL[0], MemoryClassLoader.class.getClassLoader());
			this.classBytes.putAll(classBytes);
		}

		@Override
		protected Class<?> findClass(String name) throws ClassNotFoundException {
			byte[] buf = classBytes.get(name);
			if (buf == null) {
				return super.findClass(name);
			}
			classBytes.remove(name);
			return defineClass(name, buf, 0, buf.length);
		}
	}
    
}

/**
 * JavaFileManager that keeps compiled .class bytes in memory.
 * <p>
 * Class output requested by the compiler is captured in a map of class name to
 * bytecode; every other request is forwarded to the wrapped file manager.
 */
final class MemoryJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {
	/**
	 * Java source file extension.
	 */
	private static final String EXT = ".java";

	/**
	 * Class name to bytecode. Deliberately not final: {@link #close()} swaps in a
	 * new map instead of clearing, so a map already handed out stays intact.
	 */
	private Map<String, byte[]> classBytes;

	/**
	 * @param fileManager file manager that receives all non-class-output requests
	 */
	public MemoryJavaFileManager(JavaFileManager fileManager) {
		super(fileManager);
		classBytes = new HashMap<>();
	}

	/**
	 * @return the live map of class name to bytecode collected so far
	 */
	public Map<String, byte[]> getClassBytes() {
		return classBytes;
	}

	/**
	 * Closes the wrapped file manager and starts over with an empty bytecode map.
	 * A map previously returned by {@link #getClassBytes()} is left untouched.
	 */
	@Override
	public void close() throws IOException {
		try {
			// Release the resources held by the wrapped manager.
			super.close();
		} finally {
			classBytes = new HashMap<>();
		}
	}

	/**
	 * Does nothing: class output is stored when each output stream is closed.
	 */
	@Override
	public void flush() throws IOException {
	}

	/**
	 * A file object used to represent Java source coming from a string.
	 */
	private static class StringInputBuffer extends SimpleJavaFileObject {
		final String code;

		StringInputBuffer(String name, String code) {
			super(toURI(name), Kind.SOURCE);
			this.code = code;
		}

		@Override
		public CharBuffer getCharContent(boolean ignoreEncodingErrors) {
			return CharBuffer.wrap(code);
		}
	}

	/**
	 * A file object that stores Java bytecode into the classBytes map.
	 */
	private class ClassOutputBuffer extends SimpleJavaFileObject {
		private final String name;

		ClassOutputBuffer(String name) {
			super(toURI(name), Kind.CLASS);
			this.name = name;
		}

		/**
		 * @return a stream whose content is published under this object's class
		 *         name once the stream is closed
		 */
		@Override
		public OutputStream openOutputStream() {
			// Subclass the buffer directly so bulk writes are not split into single bytes.
			return new ByteArrayOutputStream() {
				@Override
				public void close() throws IOException {
					super.close();
					classBytes.put(name, toByteArray());
				}
			};
		}
	}

	/**
	 * Returns an in-memory file object for class output and defers to the wrapped
	 * file manager for every other kind.
	 */
	@Override
	public JavaFileObject getJavaFileForOutput(Location location, String className,
                                               JavaFileObject.Kind kind, FileObject sibling) throws IOException {
		if (kind == JavaFileObject.Kind.CLASS) {
			return new ClassOutputBuffer(className);
		}
		return super.getJavaFileForOutput(location, className, kind, sibling);
	}

	/**
	 * Creates a source file object backed by a string.
	 *
	 * @param name file name of the source, e.g. {@code Foo.java}
	 * @param code source text
	 * @return file object of kind SOURCE whose content is {@code code}
	 */
	static JavaFileObject makeStringSource(String name, String code) {
		return new StringInputBuffer(name, code);
	}

	/**
	 * Builds the URI identifying a source file or class name.
	 * <p>
	 * If {@code name} denotes an existing file, that file's URI is returned.
	 * Otherwise an {@code mfm:///} URI is built with dots turned into slashes,
	 * keeping a trailing {@code .java} extension intact.
	 *
	 * @param name file name or dotted class name
	 * @return the URI; a fixed fallback URI if one cannot be built from {@code name}
	 */
	static URI toURI(String name) {
		File file = new File(name);
		if (file.exists()) {
			return file.toURI();
		}
		try {
			StringBuilder newUri = new StringBuilder("mfm:///");
			newUri.append(name.replace('.', '/'));
			if (name.endsWith(EXT)) {
				// The replace above turned ".java" into "/java"; restore the extension.
				newUri.replace(newUri.length() - EXT.length(), newUri.length(), EXT);
			}
			return URI.create(newUri.toString());
		} catch (Exception exp) {
			return URI.create("mfm:///com/sun/script/java/java_source");
		}
	}
}